SAM 2 Annotation Tool¶
In this notebook, I walk through a user-friendly tool I created that allows you to accurately label a video for object tracking tasks.¶
The tool annotates the video by passing it through Meta's SAM 2 model and allowing a human-in-the-loop to correct its mistakes. SAM 2 is specifically designed for such a use case, as it is a promptable visual segmentation (PVS) model. Thus, before any object can be tracked, it must be identified in a given frame with a point(s), a bounding box, or a mask. After the initial prompt, SAM 2 will then track the object(s) throughout the video. If a given masklet is lost (e.g., from an occlusion), SAM 2 will require a new prompt in order to regain it.¶
SAM 2's transformer-based architecture learns both motion- and appearance-based features, outperforming many of the top existing tracker models. Its promptable nature also makes it especially well-suited for providing initial high-fidelity labels that can be further refined with just a few clicks.¶
Steps
1) Load in SAM 2 and the video
2) Set the initial prompt
3) Run inference with SAM 2
4) Find the frame(s) with an object whose masklet was lost
5) Re-label said frame(s) to regain the masklet
6) Re-run inference with the correction
7) Output the final labeled data
In [2]:
import os
import shutil
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
from io import BytesIO
from matplotlib.widgets import Slider, Cursor, Button, TextBox
from matplotlib.patches import Rectangle
import ipywidgets as widgets
from IPython.display import display
Set up the environment¶
In [3]:
%%capture
!git clone https://github.com/facebookresearch/segment-anything-2.git
os.chdir('segment-anything-2')
!pip install -e .
!./checkpoints/download_ckpts.sh
!python setup.py clean
!python setup.py build_ext --inplace
In [1]:
# Use bfloat16 autocast for inference; on Ampere (compute capability 8.x)
# and newer GPUs, also enable TF32 matmuls for extra speed
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
Load In SAM 2 and the Video¶
First load in an instance of the SAM 2 predictor, choosing from its "tiny", "small", or "large" versions.¶
In [6]:
from sam2.build_sam import build_sam2_video_predictor
model_size = "tiny" # Set to: 'tiny', 'small', or 'large'
sam2_checkpoint = f"sam2_hiera_{model_size}.pt"
model_cfg = f"sam2_hiera_{model_size[0]}.yaml"
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
The functions below allow us to visualize SAM 2's prompts and output. More specifically, they annotate the video frames with the user-selected points, the resulting segmented mask, and the implied bounding box.¶
In [2]:
# Function to draw a mask (and its implied bounding box)
def show_mask(mask, ax, obj_id=None, random_color=False):
    # Generate a color for the mask
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        cmap = plt.get_cmap("tab10")
        cmap_idx = 0 if obj_id is None else obj_id
        color = np.array([*cmap(cmap_idx)[:3], 0.6])
    # Plot the mask and its bounding box only when the mask is non-empty
    if mask is not None and np.any(mask):
        h, w = mask.shape[-2:]
        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        # Plot the bounding box
        x_min, y_min, x_max, y_max = mask_to_bb(mask)
        rect = Rectangle((x_min, y_min), x_max - x_min, y_max - y_min, linewidth=2, edgecolor=color, facecolor='none')
        ax.add_patch(rect)
        ax.imshow(mask_image)

# Function to display the selected points for labeling
def show_points(coords, labels, ax, marker_size=200):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

# Function to derive a bounding box from a mask
def mask_to_bb(mask):
    # Returns None for an empty or missing mask
    if mask is not None and np.any(mask):
        rows, cols = np.where(mask.squeeze())
        y_min, y_max = rows.min(), rows.max()
        x_min, x_max = cols.min(), cols.max()
        return (x_min, y_min, x_max, y_max)
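Since `mask_to_bb` silently returns `None` for an empty mask, a quick sanity check on a toy mask helps confirm the `(x_min, y_min, x_max, y_max)` ordering. The helper is restated here (same logic as above) so the snippet is self-contained:¶

```python
import numpy as np

# Same logic as mask_to_bb above, restated for a standalone check
def mask_to_bb(mask):
    if mask is not None and np.any(mask):
        rows, cols = np.where(mask.squeeze())
        return (cols.min(), rows.min(), cols.max(), rows.max())

# Toy mask covering rows 2-4 and columns 3-6
mask = np.zeros((1, 6, 8), dtype=bool)
mask[0, 2:5, 3:7] = True
# Bounding box: x_min=3, y_min=2, x_max=6, y_max=4
bb = mask_to_bb(mask)
```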
Store the video as a list of JPEG frames with filenames like <frame_index>.jpg.¶
In [4]:
def load_video(input_video, output_folder, reshape_scale=1.0):
    # Ensure the output folder exists
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)
    # Open the video file
    cap = cv2.VideoCapture(input_video)
    # Check if the video opened successfully
    if not cap.isOpened():
        raise ValueError("Error opening video file")
    # Get video properties
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    # Loop through all frames
    frame_number = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        # Scale the frame by `reshape_scale`
        height, width = frame.shape[:2]
        new_dim = (int(width * reshape_scale), int(height * reshape_scale))
        resized_frame = cv2.resize(frame, new_dim, interpolation=cv2.INTER_AREA)
        # Construct filename with leading zeros
        filename = os.path.join(output_folder, f'{frame_number:05d}.jpg')
        # Save the frame as an image
        cv2.imwrite(filename, resized_frame)
        frame_number += 1
    # Release the video capture object
    cap.release()
    print("Frames have been extracted and saved.")
In [5]:
# Define input and output paths
input_video = 'marshawn.mp4'
# `video_dir` is a directory of JPEG frames with filenames like `<frame_index>.jpg`
video_dir = 'marshawn_frames'
# Load in the video for SAM2 processing
load_video(input_video, video_dir, 1)
Frames have been extracted and saved.
In [6]:
# Scan all the JPEG frame names in this directory
frame_names = [
p for p in os.listdir(video_dir)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
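Sorting by the integer stem (rather than lexicographically) matters if the zero-padding ever differs across filenames. A small illustration with invented names:¶

```python
import os

names = ["10.jpg", "2.jpg", "1.jpg"]
# Lexicographic order puts "10.jpg" before "2.jpg"
lex_order = sorted(names)
# Sorting by the integer stem restores the true frame order
num_order = sorted(names, key=lambda p: int(os.path.splitext(p)[0]))
```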
Initialize the inference state with the video frames, from which SAM 2 will run segmentation inference.¶
In [12]:
inference_state = predictor.init_state(video_path=video_dir)
frame loading (JPEG): 100%|██████████| 239/239 [00:12<00:00, 18.91it/s]
Set the Initial Prompt¶
The function below creates a user-friendly interactive window for selecting points on the object(s) that you want SAM 2 to track.¶
In [19]:
def annotate(frame_idx):
    # Initialize the data structure of (object_id, points, labels) tuples
    global data
    data = []
    # Load the image
    image_path = os.path.join(video_dir, frame_names[frame_idx])
    img = Image.open(image_path)
    # Set up the plot
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.set_title(f"Select Points and Assign Labels: Frame {frame_idx}")
    ax.imshow(img)
    ax.set_xlim(0, img.width)
    ax.set_ylim(img.height, 0)  # Invert y-axis to match image coordinates
    # Overlay any masks already predicted for this frame
    if 'video_segments' in globals() and frame_idx in video_segments:
        for out_obj_id, out_mask in video_segments[frame_idx].items():
            if out_mask is not None and np.any(out_mask):
                show_mask(out_mask, ax, obj_id=out_obj_id)
    # Initialize lists to store points and labels for the current object
    global current_points, current_labels, current_obj_id, current_label
    current_points = []
    current_labels = []
    current_obj_id = 1
    current_label = 1

    # Callback: record each clicked point and plot it
    def on_click(event):
        if event.inaxes == ax:
            x, y = int(event.xdata), int(event.ydata)
            current_points.append([x, y])
            current_labels.append(current_label)
            ax.text(x, y, str(current_obj_id), color='blue' if current_label == 1 else 'red', fontsize=12, ha='center')
            fig.canvas.draw()
            print(f"Point added: ({x}, {y}) with label {current_label} for object {current_obj_id}")
            # Update `data` with each click
            existing_obj = next((item for item in data if item[0] == current_obj_id), None)
            if existing_obj:
                existing_obj[1].append([x, y])
                existing_obj[2].append(current_label)
            else:
                data.append((current_obj_id, [[x, y]], [current_label]))
            print(f"Data updated for object {current_obj_id}: {data}")

    # Connect the click event to the callback function
    cid = fig.canvas.mpl_connect('button_press_event', on_click)

    # Callback: update the label used for subsequent clicks
    def update_label(change):
        global current_label
        current_label = int(change['new'])
        print(f"Current label set to {current_label}")

    # Create dropdown for label selection using ipywidgets
    label_dropdown = widgets.Dropdown(
        options=[1, 0],
        value=current_label,
        description='Label:',
    )
    label_dropdown.observe(update_label, names='value')
    display(label_dropdown)

    # Callback: switch to a different object ID
    def update_obj_id(change):
        global current_obj_id
        reset_current_object_buffers()
        current_obj_id = int(change['new'])
        print(f"Switched to object ID {current_obj_id}")

    # Clear the per-object buffers. Note that `data` is already updated on
    # every click, so extending it again here would duplicate the points.
    def reset_current_object_buffers():
        current_points.clear()
        current_labels.clear()

    # Create dropdown for object ID selection using ipywidgets
    object_id_dropdown = widgets.Dropdown(
        options=list(range(1, 51)),
        value=current_obj_id,
        description='Object ID:',
    )
    object_id_dropdown.observe(update_obj_id, names='value')
    display(object_id_dropdown)

    plt.show()
The two dropdowns allow the user to:
1) Select the Label of the clicked point (1 = point is on the object; 0 = point is not on the object)
2) Select the Object ID to distinguish multiple objects/masks from one another throughout the video.
In [1]:
frame_idx = 0
annotate(frame_idx)

These functions transform the selected points into the prompt structure with which SAM 2's predictor object is updated.¶
In [4]:
# Function to structure the selected points into prompts for SAM 2
def make_prompts(data: list):
    """
    Inputs:
    - data (list of tuples): data on the objects to be tracked, with each tuple formatted as
      (object_id, [[x1, y1], [x2, y2], ...], [label1, label2, ...])
    Outputs:
    - prompts: a dict with all the visual prompt information for SAM 2
    """
    prompts = {}
    for obj_id, points, labels in data:
        prompts[obj_id] = (
            np.array(points, dtype=np.float32),
            np.array(labels, dtype=np.int32)
        )
    return prompts

# Function to add the prompts to the SAM 2 predictor
def add_prompts(prompts, frame_idx, inference_state, is_refinement=False):
    # When refining, clear the previous prompts before re-prompting
    if is_refinement:
        predictor.reset_state(inference_state)
    # Iterate over each object in the prompts dictionary
    for obj_id, (points, labels) in prompts.items():
        # Add the new points for each object
        out_frame_idx, out_obj_ids, out_mask_logits = predictor.add_new_points(
            inference_state=inference_state,
            frame_idx=frame_idx,
            obj_id=obj_id,
            points=points,
            labels=labels
        )
    return out_frame_idx, out_obj_ids, out_mask_logits
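To make the expected shapes concrete, here is a hypothetical `data` list for two annotated objects (the coordinates are invented) and the same conversion `make_prompts` performs, restated so the snippet stands alone:¶

```python
import numpy as np

# Hypothetical annotation data: (object_id, [[x, y], ...], [label, ...]),
# where label 1 = point on the object and 0 = point off the object
data = [
    (1, [[150, 200], [180, 210]], [1, 1]),  # two positive clicks on object 1
    (2, [[400, 120]], [0]),                 # one negative click for object 2
]

# The same conversion make_prompts performs
prompts = {
    obj_id: (np.array(points, dtype=np.float32), np.array(labels, dtype=np.int32))
    for obj_id, points, labels in data
}
```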
In [21]:
# Process the user-selected labeled points
prompts = make_prompts(data)
_, out_obj_ids, out_mask_logits = add_prompts(prompts, frame_idx, inference_state)
In [23]:
# Show the results on the current (interacted) frame
plt.figure(figsize=(12, 8))
plt.title(f"frame {frame_idx}")
plt.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))
for i, out_obj_id in enumerate(out_obj_ids):
    show_points(*prompts[out_obj_id], plt.gca())
    mask = (out_mask_logits[i] > 0.0).cpu().numpy()
    show_mask(mask, plt.gca(), obj_id=out_obj_id)
Run Inference¶
In [39]:
# Function to propagate the masks throughout the video
def propagate_masks(inference_state, video_segments=None, start_frame_idx=None):
    # Avoid the mutable-default-argument pitfall
    if video_segments is None:
        video_segments = {}
    # start_frame_idx=None lets SAM 2 begin at the earliest prompted frame
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
        inference_state, start_frame_idx=start_frame_idx
    ):
        video_segments[out_frame_idx] = {
            out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
            for i, out_obj_id in enumerate(out_obj_ids)
        }
    return video_segments
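The resulting `video_segments` maps each frame index to a dict of per-object boolean masks, which makes it easy to scan for frames where a masklet was lost (step 4 of the workflow). A toy stand-in with illustrative shapes:¶

```python
import numpy as np

# Toy stand-in for SAM 2's output: {frame_idx: {obj_id: bool mask of shape (1, H, W)}}
video_segments = {
    0: {1: np.ones((1, 4, 6), dtype=bool)},
    1: {1: np.zeros((1, 4, 6), dtype=bool)},  # masklet lost on this frame
}

# Frames where some object's mask is empty are candidates for re-prompting
lost_frames = [
    f for f, masks in video_segments.items()
    if any(not np.any(m) for m in masks.values())
]
```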
In [40]:
video_segments = propagate_masks(inference_state)
propagate in video: 100%|██████████| 239/239 [00:09<00:00, 26.32it/s]
In [ ]:
# Function to visualize the outputs of the SAM 2 model
def view_labeled_frames(frame_stride, frame_names, video_segments, video_dir=video_dir):
    plt.close("all")
    for out_frame_idx in range(0, len(frame_names), frame_stride):
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.set_title(f"frame {out_frame_idx}")
        ax.imshow(Image.open(os.path.join(video_dir, frame_names[out_frame_idx])))
        if out_frame_idx in video_segments:
            for out_obj_id, out_mask in video_segments[out_frame_idx].items():
                if out_mask is not None and np.any(out_mask):
                    show_mask(out_mask, ax, obj_id=out_obj_id)
        display(fig)
        plt.close(fig)
View the Output and Find Mistakes¶
In [42]:
# Render the segmentation results every 15 frames
frame_stride = 15
view_labeled_frames(frame_stride, frame_names, video_segments, video_dir)
As seen in the frames above, SAM 2 does a great job of maintaining Marshawn Lynch's masklet, even as he runs through defenders. However, in frame 210, we notice that the mask is reduced to only the crown of his helmet, even though his legs are seen in the air. Because of this, by the time he emerges victorious in frame 225, the masklet has been completely lost.¶
SAM 2's promptable nature not only allows us to instantiate masks, but also to refine its predictions at any point in the video. We will therefore provide a corrected mask for frame 210 by selecting a new set of labeled points, letting SAM 2 recalibrate the masklet it propagates and reach even higher accuracy.¶
Re-label Faulty Masks¶
In [ ]:
frame_idx = 210
annotate(frame_idx)

In [45]:
prompts = make_prompts(data)
In [46]:
# Show the segment before further refinement
fig_before, ax_before = plt.subplots(figsize=(12, 8))
ax_before.set_title(f"Frame {frame_idx} -- Before Refinement")
ax_before.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))
for out_obj_id, out_mask in video_segments[frame_idx].items():
    if out_mask is not None and np.any(out_mask):
        show_mask(out_mask, ax_before, obj_id=out_obj_id)
display(fig_before)
plt.close(fig_before)
# Iterate over each object in the prompts dictionary and process the points and labels
_, out_obj_ids, out_mask_logits = add_prompts(prompts, frame_idx, inference_state, is_refinement=True)
# Show the segment after further refinement
fig_after, ax_after = plt.subplots(figsize=(12, 8))
ax_after.set_title(f"Frame {frame_idx} -- After Refinement")
ax_after.imshow(Image.open(os.path.join(video_dir, frame_names[frame_idx])))
for i, out_obj_id in enumerate(out_obj_ids):
    show_points(*prompts[out_obj_id], ax_after)
    mask = (out_mask_logits[i] > 0.0).cpu().numpy()
    show_mask(mask, ax_after, obj_id=out_obj_id)
display(fig_after)
plt.close(fig_after)
Now SAM 2 has captured a much better masklet for frame 210, which it will propagate henceforth.¶
Output Final Results¶
In [47]:
video_segments = propagate_masks(inference_state, video_segments, start_frame_idx=frame_idx)
propagate in video: 100%|██████████| 29/29 [00:01<00:00, 28.04it/s]
In [50]:
view_labeled_frames(15, frame_names, video_segments, video_dir)
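To complete step 7 and persist the labels outside the notebook, one option is to convert each mask to its bounding box and write the result to JSON. This is a minimal sketch, assuming `video_segments` maps frame indices to per-object boolean masks; the box derivation matches `mask_to_bb` above, and the filename `labels.json` is just an example:¶

```python
import json
import numpy as np

# Sketch of step 7: export per-frame, per-object bounding boxes to JSON.
# Assumes video_segments = {frame_idx: {obj_id: boolean mask}}.
def export_labels(video_segments, output_path="labels.json"):
    labels = {}
    for frame_idx, masks in video_segments.items():
        frame_labels = {}
        for obj_id, mask in masks.items():
            if mask is not None and np.any(mask):
                rows, cols = np.where(mask.squeeze())
                # (x_min, y_min, x_max, y_max), same ordering as mask_to_bb
                frame_labels[obj_id] = [int(cols.min()), int(rows.min()),
                                        int(cols.max()), int(rows.max())]
        labels[frame_idx] = frame_labels
    with open(output_path, "w") as f:
        json.dump(labels, f, indent=2)
    return labels
```

Full segmentation masks could be saved instead (e.g., as run-length encodings) if the downstream task needs more than boxes.¶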